{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was '/home/tom/anaconda3/envs/pyvertical-dev/lib/python3.7/site-packages/tf_encrypted/operations/secure_random/secure_random_module_tf_1.15.3.so'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/tom/anaconda3/envs/pyvertical-dev/lib/python3.7/site-packages/tf_encrypted/session.py:24: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "sys.path.append(\"..\" + os.sep + \"..\")\n",
    "\n",
    "import torch\n",
    "import syft as sy\n",
    "\n",
    "from src.future import PartitionedDataset, VerticalDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "hook = sy.TorchHook(torch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will now create a toy dataset and turn it into a PartitionedDataset. PartitionedDatasets can hold data, targets, or both."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = torch.tensor([1.0, 2.0, 3.0]).tag(\"#toy\").describe(\"Toy input data.\")\n",
    "targets = torch.tensor([0, 1, 1]).tag(\"#toy\").describe(\"Toy data labels.\")\n",
    "dataset = PartitionedDataset(data, targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Just one dataset isn't very exciting - the data is not vertically partitioned! PartitionedDatasets come with a helper method to vertically partition a dataset.\n",
    "We will first create two virtual workers, Alice and Bob, and then move the data onto them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "alice = sy.VirtualWorker(id=\"alice\", hook=hook, is_client_worker=False)\n",
    "bob = sy.VirtualWorker(id=\"bob\", hook=hook, is_client_worker=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "vertical_data = dataset.vertically_federate((alice, bob))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "src.future.dataset.VerticalDataset"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(vertical_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This new dataset is a VerticalDataset. This is similar to syft's FederatedDataset - it holds a list of vertically partitioned datasets assigned to different workers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Toy input data.\n"
     ]
    }
   ],
   "source": [
    "alice_results = alice.search([\"#toy\"])\n",
    "for res in alice_results:\n",
    "    print(res.description)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Toy data labels.\n"
     ]
    }
   ],
   "source": [
    "bob_results = bob.search([\"#toy\"])\n",
    "for res in bob_results:\n",
    "    print(res.description)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can see that Alice has the data and Bob has the labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['alice', 'bob']"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vertical_data.workers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can collect a dataset from its remote worker."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "alices_dataset = vertical_data.get_dataset(\"alice\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PartitionedDataset\n",
       "\tData: tensor([1., 2., 3.])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "alices_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['bob']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vertical_data.workers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After collecting Alice's dataset, the VerticalDataset only contains Bob's labels."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
